from sklearn.svm import SVC
from sklearn.datasets import load_digits
from sklearn import metrics
from matplotlib import pyplot as plt
import random

digits = load_digits()

# Keep the pixel data and the ground-truth labels in separate variables.
images, labels = digits.images, digits.target
print(images)
print(images.shape)
# images is 3-D at this point (1797 x 8 x 8): 1797 matrices of 8x8 pixels.
print(images.ndim)
print(labels)

# Render some of the images to get a feel for what the data looks like.
fig = plt.figure(figsize=(10, 10))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

# Show the first 100 images in a 10x10 grid, each annotated with its label.
first_hundred = zip(digits.images[:100], digits.target[:100])
for position, (pixels, digit) in enumerate(first_hundred, start=1):
    axes = fig.add_subplot(10, 10, position, xticks=[], yticks=[])
    axes.imshow(pixels, cmap=plt.cm.binary, interpolation='nearest')
    axes.text(0, 9, str(digit))

plt.show()

# Flatten each 8x8 image into a 1-D feature vector so the data becomes a
# 2-D (n_samples, n_features) array, which is the input layout SVC.fit takes.
cnt = images.shape[0]
images_vector = images.reshape(cnt, -1)
print(images_vector)
print(images_vector.shape)

# Randomly split the sample indices: the first int(cnt * 0.3) shuffled
# indices form the test set, the remainder form the training set.
# A private RNG with a fixed seed makes the split, and therefore the
# reported metrics, reproducible across runs. It also leaves the global
# `random` module state untouched.
RANDOM_SEED = 42
sample = list(range(cnt))
test_size = int(cnt * 0.3)
random.Random(RANDOM_SEED).shuffle(sample)
train, test = sample[test_size:], sample[:test_size]

# Index the arrays with the index lists to build feature/label sets.
X_train, Y_train = images_vector[train], labels[train]
X_test, Y_test = images_vector[test], labels[test]

# Fit an SVM classifier with an RBF kernel on the flattened training images.
classifier = SVC(C=1.0, gamma=0.001, kernel='rbf')
classifier.fit(X_train, Y_train)
print(classifier)

# Predict labels for the held-out images and print per-class metrics.
prediction = classifier.predict(X_test)
print(prediction)
report = metrics.classification_report(Y_test, prediction)
print(report)
